from typing import Tuple
import math
from gym.envs.registration import register
import numpy as np
from typing import Tuple, Dict, Text
from highway_env import utils
from highway_env.envs.common.abstract import AbstractEnv
from highway_env.road.lane import LineType, StraightLane, CircularLane, SineLane
from highway_env.road.road import Road, RoadNetwork
from highway_env.vehicle.controller import MDPVehicle
from highway_env.envs.common.action import Action
from highway_env.vehicle.controller import ControlledVehicle
import time
import random
import py_trees
from highway_env.vehicle.behavior import IDMVehicle
import random
import py_trees
import time
class AddRandomVehicleBehavior(py_trees.behaviour.Behaviour):
    """Behaviour-tree node that asks an environment to spawn extra vehicles.

    Every call to :meth:`update` invokes ``env.add_random_vehicle()`` and
    reports success.
    """

    def __init__(self, env, name: str = "Add Random Vehicle"):
        """Remember the environment that will receive the spawn requests.

        :param env: object exposing an ``add_random_vehicle()`` method
        :param name: behaviour name handed to the py_trees base class
        """
        super().__init__(name)
        self.env = env

    def update(self) -> py_trees.common.Status:
        """Trigger one spawn request on the environment.

        :return: always ``Status.SUCCESS``
        """
        spawn = self.env.add_random_vehicle
        spawn()
        return py_trees.common.Status.SUCCESS

class TimerBehavior(py_trees.behaviour.Behaviour):
    """Behaviour that reports RUNNING until ``duration`` seconds have elapsed.

    The countdown starts when :meth:`initialise` is called. Elapsed time is
    real (wall-clock) time read from a monotonic clock, not simulation time.
    """

    def __init__(self, duration: float, name: str = "Timer"):
        """Create a timer that has not started counting yet.

        :param duration: number of seconds to wait before succeeding
        :param name: behaviour name handed to the py_trees base class
        """
        super(TimerBehavior, self).__init__(name)
        self.duration = duration
        # Monotonic timestamp of the countdown start; None until started.
        self.start_time = None

    def initialise(self) -> None:
        """(Re)start the countdown from the current instant."""
        # A monotonic clock is immune to system clock adjustments, which
        # could otherwise make the elapsed time jump or go negative.
        self.start_time = time.monotonic()

    def update(self) -> py_trees.common.Status:
        """Check whether the configured duration has elapsed.

        :return: ``Status.SUCCESS`` once at least ``duration`` seconds have
                 passed since the countdown started, ``Status.RUNNING`` otherwise
        """
        if self.start_time is None:
            # update() was called without a prior initialise(): start counting
            # now instead of failing on ``float - None``.
            self.start_time = time.monotonic()
        elapsed = time.monotonic() - self.start_time
        if elapsed >= self.duration:
            return py_trees.common.Status.SUCCESS
        return py_trees.common.Status.RUNNING



## Switch to a discrete action space, then retrain
class RoundaboutEnv(AbstractEnv):

    @classmethod
    def default_config(cls) -> dict:
        """
        Return the parent class's default configuration with the roundabout overrides applied.

        The observation is an occupancy grid carrying only the "presence" feature and the
        action space is "DiscreteMetaAction" with target speeds 1, 3 and 5.

        NOTE: ``_reward`` in this class combines its terms with hard-coded weights, so the
        ``*_reward`` entries below are not read by it.

        :return: the configuration dictionary
        """
        observation = {
            "type": "OccupancyGrid",
            "features": ["presence"],
        }
        action = {
            "type": "DiscreteMetaAction",
            "target_speeds": [1, 3, 5],
        }
        overrides = {
            "observation": observation,
            "action": action,
            "incoming_vehicle_destination": None,
            "high_speed_reward": 0.4,
            "right_lane_reward": 0.1,
            "lane_change_reward": -0.05,
            "screen_width": 1000,
            "screen_height": 600,
            "centering_position": [0.5, 0.6],
            "duration": 65,
            "normalize_reward": True,
            "reward_speed_range": [0, 8],
            "offroad_terminal": False,
        }

        config = super().default_config()
        config.update(overrides)
        return config

    def _reward(self, action: Action) -> float:
        """
        Collapse the individual reward terms into a single scalar using fixed weights.

        The weights are -3 for "collision", +5 for "high_speed_reward" and -2 for
        "Lane_change_reward"; the result is neither normalised nor scaled by the
        configuration.

        :param action: the last action performed
        :return: the weighted sum of the reward terms
        """
        terms = self._rewards(action)
        collision_term = terms["collision"] * -3
        efficiency_term = terms["high_speed_reward"] * 5
        lane_change_term = terms["Lane_change_reward"] * -2
        return collision_term + efficiency_term + lane_change_term

    def _rewards(self, action: Action) -> Dict[Text, float]:
        """
        Compute the individual reward terms for the ego-vehicle.

        * "collision": 1.0 if the ego-vehicle has crashed, else 0.0.
        * "high_speed_reward": an efficiency term that is piecewise linear in the
          ego speed. It rises from -0.5 at speed 0 to its maximum of 1 at speed 3,
          then decreases, reaching 0 at speed 5.
        * "Lane_change_reward": 1.0 if ``action`` is 0 or 2, else 0.0.

        :param action: the last action performed
        :return: a dictionary mapping each term name to its value
        """
        # Speed at which the efficiency term peaks.
        speed_limit = 3
        speed = self.vehicle.speed
        if speed <= speed_limit:
            efficiency_reward = -1 / 2 + speed / 2
        else:
            efficiency_reward = 5 / 2 - speed / 2

        # Actions 0 and 2 count as lane changes — presumably LANE_LEFT / LANE_RIGHT
        # of the configured DiscreteMetaAction; confirm against the action type.
        lane_change = action == 0 or action == 2

        return {
            "collision": float(self.vehicle.crashed),
            "high_speed_reward": efficiency_reward,
            "Lane_change_reward": float(lane_change),
        }


    def _is_terminated(self) -> bool:
        """
        Tell whether the episode has reached a terminal state.

        The episode terminates when the ego-vehicle has crashed, or when the
        "offroad_terminal" option is enabled and the ego-vehicle is off the road.
        """
        crashed = self.vehicle.crashed
        if crashed:
            return crashed
        offroad_terminal = self.config["offroad_terminal"]
        if not offroad_terminal:
            return offroad_terminal
        return not self.vehicle.on_road

    def _is_truncated(self) -> bool:
        """
        Tell whether the episode must be cut short.

        True once the elapsed time reaches the configured "duration".
        NOTE(review): a crash of the ego-vehicle is also reported as a truncation
        here, in addition to being a termination — confirm this is intended.
        """
        time_is_up = self.time >= self.config["duration"]
        if time_is_up:
            return time_is_up
        return self.vehicle.crashed

    def _reset(self) -> None:
        """Rebuild the scene: first the road network, then the vehicles placed on it."""
        for build_step in (self._make_road, self._make_vehicles):
            build_step()

    def _make_road(self) -> None:
        """
        Build the roundabout road network and store it in ``self.road``.

        The network consists of:

        * two concentric circular lanes (radii ``radius`` and ``radius + 4``), each cut
          into eight arcs between the nodes se, ex, ee, nx, ne, wx, we and sx;
        * for each of the four sides, a straight approach lane followed by a sine lane
          leading to the circle's entry node, and a sine lane followed by a straight
          lane leading away from the circle's exit node.

        :return: None; the resulting ``Road`` is assigned to ``self.road``
        """
        # Circle lanes: (s)outh/(e)ast/(n)orth/(w)est (e)ntry/e(x)it.
        center = [0, 0]  # [m]
        radius = 20  # [m]
        alpha = 24  # [deg]

        net = RoadNetwork()
        # Lane 0 uses the first radius, lane 1 the second (4 m larger).
        radii = [radius, radius + 4]
        n, c, s = LineType.NONE, LineType.CONTINUOUS, LineType.STRIPED
        # Line types per circle lane: lane 0 -> (continuous, striped), lane 1 -> (none, continuous).
        line = [[c, s], [n, c]]
        # Both lanes are added for every arc, lane 0 first, so that they share the same
        # (from, to) node pair and are distinguished by their lane index.
        # Arcs from an exit node to the next entry node span 2 * alpha degrees.
        for lane in [0, 1]:
            net.add_lane("se", "ex",
                         CircularLane(center, radii[lane], np.deg2rad(90 - alpha), np.deg2rad(alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("ex", "ee",
                         CircularLane(center, radii[lane], np.deg2rad(alpha), np.deg2rad(-alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("ee", "nx",
                         CircularLane(center, radii[lane], np.deg2rad(-alpha), np.deg2rad(-90 + alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("nx", "ne",
                         CircularLane(center, radii[lane], np.deg2rad(-90 + alpha), np.deg2rad(-90 - alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("ne", "wx",
                         CircularLane(center, radii[lane], np.deg2rad(-90 - alpha), np.deg2rad(-180 + alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("wx", "we",
                         CircularLane(center, radii[lane], np.deg2rad(-180 + alpha), np.deg2rad(-180 - alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("we", "sx",
                         CircularLane(center, radii[lane], np.deg2rad(180 - alpha), np.deg2rad(90 + alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("sx", "se",
                         CircularLane(center, radii[lane], np.deg2rad(90 + alpha), np.deg2rad(90 - alpha),
                                      clockwise=False, line_types=line[lane]))

        # Access lanes: (r)oad/(s)ine
        access = 170  # [m]
        dev = 85  # [m]
        a = 5  # [m]
        delta_st = 0.2 * dev  # [m]

        # Remaining part of dev once delta_st is removed.
        delta_en = dev - delta_st
        # 2 * pi / dev, passed to every SineLane below.
        w = 2 * np.pi / dev
        # South side: approach "ser" -> "ses" -> "se", exit "sx" -> "sxs" -> "sxr".
        net.add_lane("ser", "ses", StraightLane([2, access], [2, dev / 2], line_types=(s, c)))
        net.add_lane("ses", "se",
                     SineLane([2 + a, dev / 2], [2 + a, dev / 2 - delta_st], a, w, -np.pi / 2, line_types=(c, c)))
        net.add_lane("sx", "sxs",
                     SineLane([-2 - a, -dev / 2 + delta_en], [-2 - a, dev / 2], a, w, -np.pi / 2 + w * delta_en,
                              line_types=(c, c)))
        net.add_lane("sxs", "sxr", StraightLane([-2, dev / 2], [-2, access], line_types=(n, c)))

        # East side: approach "eer" -> "ees" -> "ee", exit "ex" -> "exs" -> "exr".
        net.add_lane("eer", "ees", StraightLane([access, -2], [dev / 2, -2], line_types=(s, c)))
        net.add_lane("ees", "ee",
                     SineLane([dev / 2, -2 - a], [dev / 2 - delta_st, -2 - a], a, w, -np.pi / 2, line_types=(c, c)))
        net.add_lane("ex", "exs",
                     SineLane([-dev / 2 + delta_en, 2 + a], [dev / 2, 2 + a], a, w, -np.pi / 2 + w * delta_en,
                              line_types=(c, c)))
        net.add_lane("exs", "exr", StraightLane([dev / 2, 2], [access, 2], line_types=(n, c)))

        # North side: approach "ner" -> "nes" -> "ne", exit "nx" -> "nxs" -> "nxr".
        net.add_lane("ner", "nes", StraightLane([-2, -access], [-2, -dev / 2], line_types=(s, c)))
        net.add_lane("nes", "ne",
                     SineLane([-2 - a, -dev / 2], [-2 - a, -dev / 2 + delta_st], a, w, -np.pi / 2, line_types=(c, c)))
        net.add_lane("nx", "nxs",
                     SineLane([2 + a, dev / 2 - delta_en], [2 + a, -dev / 2], a, w, -np.pi / 2 + w * delta_en,
                              line_types=(c, c)))
        net.add_lane("nxs", "nxr", StraightLane([2, -dev / 2], [2, -access], line_types=(n, c)))

        # West side: approach "wer" -> "wes" -> "we", exit "wx" -> "wxs" -> "wxr".
        net.add_lane("wer", "wes", StraightLane([-access, 2], [-dev / 2, 2], line_types=(s, c)))
        net.add_lane("wes", "we",
                     SineLane([-dev / 2, 2 + a], [-dev / 2 + delta_st, 2 + a], a, w, -np.pi / 2, line_types=(c, c)))
        net.add_lane("wx", "wxs",
                     SineLane([dev / 2 - delta_en, -2 - a], [-dev / 2, -2 - a], a, w, -np.pi / 2 + w * delta_en,
                              line_types=(c, c)))
        net.add_lane("wxs", "wxr", StraightLane([-dev / 2, -2], [-access, -2], line_types=(n, c)))

        # The road shares the environment's random generator; trajectory history is
        # recorded according to the "show_trajectories" option.
        road = Road(network=net, np_random=self.np_random, record_history=self.config["show_trajectories"])
        self.road = road

    def add_random_vehicle(self) -> None:
        """
        Spawn one extra traffic vehicle on each of the south, east and north approach lanes.

        Each vehicle is created about 50 m along its approach lane (with Gaussian
        jitter), routed to a destination drawn at random among the four exit roads,
        given randomized behaviour parameters and appended to the road.
        The west approach, where ``_make_vehicles`` places the ego-vehicle, is not used.

        :return: None; the new vehicles are appended to ``self.road.vehicles``
        """
        # (approach lane index, initial speed) for every vehicle to spawn.
        spawn_points = [
            (("ser", "ses", 0), 3),
            (("eer", "ees", 0), 3),
            (("ner", "nes", 0), 4),
        ]
        destinations = ["exr", "sxr", "nxr", "wxr"]
        position_deviation = 2
        other_vehicles_type = utils.class_from_path(self.config["other_vehicles_type"])

        for lane_index, speed in spawn_points:
            vehicle = other_vehicles_type.make_on_lane(
                self.road,
                lane_index,
                longitudinal=50 + self.np_random.normal() * position_deviation,
                speed=speed)
            # A single draw per vehicle: the destination that is drawn is the one used.
            vehicle.plan_route_to(self.np_random.choice(destinations))
            vehicle.randomize_behavior()
            self.road.vehicles.append(vehicle)




    def _make_vehicles(self) -> None:
        """
        Populate the road with the ego-vehicle and the initial traffic.

        * The ego-vehicle starts 100 m along the west approach lane at speed 1 and is
          routed to "exr" when its class supports route planning.
        * Six traffic vehicles are placed on circle lane ("we", "sx", 1), one on the
          south approach lane and one on the north approach lane.
        * A py_trees sequence (a 0.3 s timer followed by ``AddRandomVehicleBehavior``)
          is stored on the ego-vehicle as ``behavior_tree``.

        :return: None; the ego-vehicle is stored in ``self.vehicle``
        """
        position_deviation = 2

        # Ego-vehicle
        ego_lane = self.road.network.get_lane(("wer", "wes", 0))
        ego_vehicle = self.action_type.vehicle_class(self.road,
                                                     ego_lane.position(100, 0),
                                                     speed=1,
                                                     heading=ego_lane.heading_at(0))
        try:
            ego_vehicle.plan_route_to("exr")
        except AttributeError:
            # Deliberate best effort: vehicle classes without route planning stay unrouted.
            pass
        # The ego-vehicle must be listed on the road exactly once: a second entry
        # would make the road process it twice.
        self.road.vehicles.append(ego_vehicle)
        self.vehicle = ego_vehicle
        self.controlled_vehicles.append(ego_vehicle)

        destinations = ["exr", "sxr", "nxr"]
        other_vehicles_type = utils.class_from_path(self.config["other_vehicles_type"])

        def release(traffic_vehicle, destination) -> None:
            """Route a freshly created traffic vehicle, randomize it and put it on the road."""
            traffic_vehicle.plan_route_to(destination)
            traffic_vehicle.randomize_behavior()
            self.road.vehicles.append(traffic_vehicle)

        # Incoming vehicle: its destination can be forced through the configuration.
        incoming = other_vehicles_type.make_on_lane(self.road,
                                                    ("we", "sx", 1),
                                                    longitudinal=5,
                                                    speed=3)
        if self.config["incoming_vehicle_destination"] is not None:
            destination = destinations[self.config["incoming_vehicle_destination"]]
        else:
            destination = self.np_random.choice(destinations)
        release(incoming, destination)

        # Other vehicles, spread along the same circle lane.
        for i in (1, 2, 3, 4, -1):
            vehicle = other_vehicles_type.make_on_lane(
                self.road,
                ("we", "sx", 1),
                longitudinal=28 * i + self.np_random.normal() * position_deviation,
                # Drawn from the environment's generator so that seeding stays effective.
                speed=3 + self.np_random.uniform())
            release(vehicle, self.np_random.choice(destinations))

        # Entering vehicles on the south and north approach lanes.
        for lane_index, speed in ((("ser", "ses", 0), 3.2), (("ner", "nes", 0), 3)):
            vehicle = other_vehicles_type.make_on_lane(
                self.road,
                lane_index,
                longitudinal=50 + self.np_random.normal() * position_deviation,
                speed=speed)
            release(vehicle, self.np_random.choice(destinations))

        # Build a behaviour tree: a timer node followed by a node that adds random vehicles.
        behavior_tree = py_trees.composites.Sequence("Behavior Tree")

        # Timer node: succeeds once 0.3 s have elapsed.
        timer = TimerBehavior(duration=0.3)

        # Custom node that asks this environment to add random vehicles.
        add_vehicle_behavior = AddRandomVehicleBehavior(env=self)

        behavior_tree.add_children([timer, add_vehicle_behavior])

        # Attach the tree to the ego-vehicle.
        # NOTE(review): nothing visible in this file ticks the tree — confirm where it is run.
        ego_vehicle.behavior_tree = behavior_tree

    # Register this environment under the id 'roundabout-v0'.
    # NOTE(review): this call is indented inside the RoundaboutEnv class body, so it
    # runs while the class is being defined — confirm whether it was meant to sit at
    # module level instead.
    register(
        id='roundabout-v0',
        entry_point='highway_env.envs:RoundaboutEnv',
    )





